{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## VAE MNIST example: BO in a latent space"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, we use the MNIST dataset and some standard PyTorch examples to show a synthetic problem where the input to the objective function is a `28 x 28` image. The main idea is to train a [variational auto-encoder (VAE)](https://arxiv.org/abs/1312.6114) on the MNIST dataset and run Bayesian Optimization in the latent space. We also refer readers to [this tutorial](http://krasserm.github.io/2018/04/07/latent-space-optimization/), which discusses [the method](https://arxiv.org/abs/1610.02415) of jointly training a VAE with a predictor (e.g., classifier), and shows a similar tutorial for the MNIST setting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets # transforms\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "dtype = torch.float"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Problem setup\n",
    "\n",
    "Let's first define our synthetic expensive-to-evaluate objective function. We assume that it takes the following form:\n",
    "\n",
    "$$\\text{image} \\longrightarrow \\text{image classifier} \\longrightarrow \\text{scoring function} \n",
    "\\longrightarrow \\text{score}.$$\n",
    "\n",
    "The classifier is a convolutional neural network (CNN) trained using the architecture of the [PyTorch CNN example](https://github.com/pytorch/examples/tree/master/mnist)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 20, 5, 1)\n",
    "        self.conv2 = nn.Conv2d(20, 50, 5, 1)\n",
    "        self.fc1 = nn.Linear(4 * 4 * 50, 500)\n",
    "        self.fc2 = nn.Linear(500, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.conv1(x))\n",
    "        x = F.max_pool2d(x, 2, 2)\n",
    "        x = F.relu(self.conv2(x))\n",
    "        x = F.max_pool2d(x, 2, 2)\n",
    "        x = x.view(-1, 4*4*50)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return F.log_softmax(x, dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We next instantiate the CNN for digit recognition and load a pre-trained model.\n",
    "\n",
    "Here, you may have to change `PRETRAINED_LOCATION` to the location of the `pretrained_models` folder on your machine."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "PRETRAINED_LOCATION = \"./pretrained_models\"\n",
    "\n",
    "cnn_model = Net().to(device)\n",
    "cnn_state_dict = torch.load(os.path.join(PRETRAINED_LOCATION, \"mnist_cnn.pt\"), map_location=device)\n",
    "cnn_model.load_state_dict(cnn_state_dict);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our VAE model follows the [PyTorch VAE example](https://github.com/pytorch/examples/tree/master/vae), except that we use the same data transform from the CNN tutorial for consistency. We then instantiate the model and again load a pre-trained model. To train these models, we refer readers to the PyTorch Github repository. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VAE(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(784, 400)\n",
    "        self.fc21 = nn.Linear(400, 20)\n",
    "        self.fc22 = nn.Linear(400, 20)\n",
    "        self.fc3 = nn.Linear(20, 400)\n",
    "        self.fc4 = nn.Linear(400, 784)\n",
    "\n",
    "    def encode(self, x):\n",
    "        h1 = F.relu(self.fc1(x))\n",
    "        return self.fc21(h1), self.fc22(h1)\n",
    "\n",
    "    def reparameterize(self, mu, logvar):\n",
    "        std = torch.exp(0.5*logvar)\n",
    "        eps = torch.randn_like(std)\n",
    "        return mu + eps*std\n",
    "\n",
    "    def decode(self, z):\n",
    "        h3 = F.relu(self.fc3(z))\n",
    "        return torch.sigmoid(self.fc4(h3))\n",
    "\n",
    "    def forward(self, x):\n",
    "        mu, logvar = self.encode(x.view(-1, 784))\n",
    "        z = self.reparameterize(mu, logvar)\n",
    "        return self.decode(z), mu, logvar\n",
    "\n",
    "vae_model = VAE().to(device)\n",
    "vae_state_dict = torch.load(os.path.join(PRETRAINED_LOCATION, \"mnist_vae.pt\"), map_location=device)\n",
    "vae_model.load_state_dict(vae_state_dict);"
   ]
  },
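  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `reparameterize` method draws a latent sample as $z = \\mu + \\epsilon \\cdot \\sigma$, with $\\sigma = \\exp(0.5 \\cdot \\text{logvar})$ and $\\epsilon \\sim \\mathcal{N}(0, 1)$. As a rough illustration, here is a minimal scalar sketch in plain Python; the function name `reparameterize_scalar` and the chosen values of `mu` and `logvar` are hypothetical.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import random\n",
    "\n",
    "def reparameterize_scalar(mu, logvar, rng):\n",
    "    # std = exp(0.5 * logvar), so logvar = 0 corresponds to unit variance.\n",
    "    std = math.exp(0.5 * logvar)\n",
    "    eps = rng.gauss(0.0, 1.0)\n",
    "    return mu + eps * std\n",
    "\n",
    "rng = random.Random(0)\n",
    "# Hypothetical parameters: mean 2.0 and variance 0.25 (standard deviation 0.5).\n",
    "samples = [reparameterize_scalar(2.0, math.log(0.25), rng) for _ in range(10000)]\n",
    "mean = sum(samples) / len(samples)\n",
    "var = sum((s - mean) ** 2 for s in samples) / len(samples)\n",
    "print(round(mean, 2), round(var, 2))\n",
    "```\n",
    "\n",
    "The sample mean and variance should land close to 2.0 and 0.25. In this tutorial the BO loop searches over $z$ directly and only calls the decoder, so this sampling step matters when training the VAE rather than during optimization."
   ]
  },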
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now define the scoring function that maps digits to scores. The function below prefers the digit '3'."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def score(y):\n",
    "    \"\"\"Returns a 'score' for each digit from 0 to 9. It is modeled as a squared exponential\n",
    "    centered at the digit '3'.\n",
    "    \"\"\"\n",
    "    return torch.exp(-2 * (y - 3)**2)"
   ]
  },
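  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for how sharply this scoring function peaks, here is a minimal sketch that evaluates the same formula, $\\exp(-2 (y - 3)^2)$, in plain Python for each digit. The name `digit_score` is a hypothetical stand-in for `score` above, used only so that the snippet runs without PyTorch.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def digit_score(y, target=3):\n",
    "    # Squared-exponential bump centered at `target` (hypothetical helper).\n",
    "    return math.exp(-2 * (y - target) ** 2)\n",
    "\n",
    "scores = [digit_score(d) for d in range(10)]\n",
    "for d, s in enumerate(scores):\n",
    "    print(d, round(s, 4))\n",
    "```\n",
    "\n",
    "Only the digit 3 receives a score of 1. Its neighbors 2 and 4 score $e^{-2} \\approx 0.135$, and digits further away score close to zero."
   ]
  },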
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Given the scoring function, we can now write our overall objective, which as discussed above, starts with an image and outputs a score. Let's say the objective computes the expected score given the probabilities from the classifier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def score_image_recognition(x):\n",
    "    \"\"\"The input x is an image and an expected score based on the CNN classifier and\n",
    "    the scoring function is returned.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        probs = torch.exp(cnn_model(x))  # b x 10\n",
    "        scores = score(torch.arange(10, device=device, dtype=dtype)).expand(probs.shape)\n",
    "    return (probs * scores).sum(dim=1)"
   ]
  },
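  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The expected score is a probability-weighted sum of the per-digit scores, $\\sum_{k=0}^{9} p_k \\, s(k)$. Here is a minimal sketch in plain Python, with hypothetical hand-written probability vectors standing in for the CNN output (`expected_score` is an illustrative helper, not part of the tutorial code).\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def expected_score(probs, target=3):\n",
    "    # probs: class probabilities for digits 0..9 (assumed to sum to 1).\n",
    "    scores = [math.exp(-2 * (k - target) ** 2) for k in range(10)]\n",
    "    return sum(p * s for p, s in zip(probs, scores))\n",
    "\n",
    "# Hypothetical classifier outputs: fully confident in '3' vs. completely unsure.\n",
    "confident_three = [0.0] * 10\n",
    "confident_three[3] = 1.0\n",
    "uniform = [0.1] * 10\n",
    "\n",
    "print(expected_score(confident_three))\n",
    "print(expected_score(uniform))\n",
    "```\n",
    "\n",
    "A classifier that is certain the image is a 3 yields the maximum score of 1, while an ambiguous image scores much lower. The objective therefore rewards latent points that decode to clearly recognizable 3s."
   ]
  },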
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we define a helper function `decode` that takes as input the parameters `mu` and `logvar` of the variational distribution and performs reparameterization and the decoding. We use batched Bayesian optimization to search over the parameters `mu` and `logvar`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode(train_x):\n",
    "    with torch.no_grad():\n",
    "        decoded = vae_model.decode(train_x)\n",
    "    return decoded.view(train_x.shape[0], 1, 28, 28)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Model initialization and initial random batch\n",
    "\n",
    "We use a `SingleTaskGP` to model the score of an image generated by a latent representation. The model is initialized with points drawn from $[-6, 6]^{20}$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from botorch.models import SingleTaskGP\n",
    "from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood\n",
    "\n",
    "\n",
    "bounds = torch.tensor([[-6.0] * 20, [6.0] * 20], device=device, dtype=dtype)\n",
    "\n",
    "\n",
    "def initialize_model(n=5):\n",
    "    # generate training data  \n",
    "    train_x = (bounds[1] - bounds[0]) * torch.rand(n, 20, device=device, dtype=dtype) + bounds[0]\n",
    "    train_obj = score_image_recognition(decode(train_x))\n",
    "    best_observed_value = train_obj.max().item()\n",
    "    \n",
    "    # define models for objective and constraint\n",
    "    model = SingleTaskGP(train_X=train_x, train_Y=train_obj)\n",
    "    model = model.to(train_x)\n",
    "    \n",
    "    mll = ExactMarginalLogLikelihood(model.likelihood, model)\n",
    "    mll = mll.to(train_x)\n",
    "    \n",
    "    return train_x, train_obj, mll, model, best_observed_value"
   ]
  },
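  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The initial design rescales uniform samples from $[0, 1)$ to the box defined by `bounds`. Here is a minimal sketch of that rescaling in plain Python, assuming the same bounds of $-6$ and $6$ in each of the 20 latent dimensions (`sample_in_box` is a hypothetical helper).\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "def sample_in_box(n, dim, lower, upper, rng):\n",
    "    # lower + (upper - lower) * u maps u in [0, 1) into the box [lower, upper].\n",
    "    return [\n",
    "        [lower + (upper - lower) * rng.random() for _ in range(dim)]\n",
    "        for _ in range(n)\n",
    "    ]\n",
    "\n",
    "rng = random.Random(1)\n",
    "points = sample_in_box(n=5, dim=20, lower=-6.0, upper=6.0, rng=rng)\n",
    "print(len(points), len(points[0]))\n",
    "```\n",
    "\n",
    "Each of the 5 rows plays the role of one latent vector that would be decoded into an image and scored."
   ]
  },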
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Define a helper function that performs the essential BO step\n",
    "The helper function below takes an acquisition function as an argument, optimizes it, and returns the batch $\\{x_1, x_2, \\ldots x_q\\}$ along with the observed function values. For this example, we'll use a small batch of $q=3$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from botorch.optim import joint_optimize\n",
    "\n",
    "\n",
    "BATCH_SIZE = 3\n",
    "\n",
    "\n",
    "def optimize_acqf_and_get_observation(acq_func):\n",
    "    \"\"\"Optimizes the acquisition function, and returns a new candidate and a noisy observation\"\"\"\n",
    "    \n",
    "    # optimize\n",
    "    candidates = joint_optimize(\n",
    "        acq_function=acq_func,\n",
    "        bounds=bounds,\n",
    "        q=BATCH_SIZE,\n",
    "        num_restarts=10,\n",
    "        raw_samples=200,\n",
    "    )\n",
    "\n",
    "    # observe new values \n",
    "    new_x = candidates.detach()\n",
    "    new_obj = score_image_recognition(decode(new_x))\n",
    "    return new_x, new_obj"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Perform Bayesian Optimization loop with qEI\n",
    "The Bayesian optimization \"loop\" for a batch size of $q$ simply iterates the following steps: (1) given a surrogate model, choose a batch of points $\\{x_1, x_2, \\ldots x_q\\}$, (2) observe $f(x)$ for each $x$ in the batch, and (3) update the surrogate model. We run `N_BATCH=75` iterations. The acquisition function is approximated using `MC_SAMPLES=2000` samples. We also initialize the model with 5 randomly drawn points."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from botorch import fit_gpytorch_model\n",
    "from botorch.acquisition.monte_carlo import qExpectedImprovement\n",
    "from botorch.sampling.samplers import SobolQMCNormalSampler\n",
    "\n",
    "seed=1\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "N_BATCH = 50\n",
    "MC_SAMPLES = 2000\n",
    "best_observed = []\n",
    "\n",
    "# call helper function to initialize model\n",
    "train_x, train_obj, mll, model, best_value = initialize_model(n=5)\n",
    "best_observed.append(best_value)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are now ready to run the BO loop (this make take a few minutes, depending on your machine)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Running BO .................................................."
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "print(f\"\\nRunning BO \", end='')\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# run N_BATCH rounds of BayesOpt after the initial random batch\n",
    "for iteration in range(N_BATCH):    \n",
    "\n",
    "    # fit the model\n",
    "    fit_gpytorch_model(mll)\n",
    "\n",
    "    # define the qNEI acquisition module using a QMC sampler\n",
    "    qmc_sampler = SobolQMCNormalSampler(num_samples=MC_SAMPLES, seed=seed)\n",
    "    qEI = qExpectedImprovement(model=model, sampler=qmc_sampler, best_f=best_value)\n",
    "\n",
    "    # optimize and get new observation\n",
    "    new_x, new_obj = optimize_acqf_and_get_observation(qEI)\n",
    "\n",
    "    # update training points\n",
    "    train_x = torch.cat((train_x, new_x))\n",
    "    train_obj = torch.cat((train_obj, new_obj))\n",
    "\n",
    "    # update progress\n",
    "    best_value = score_image_recognition(decode(train_x)).max().item()\n",
    "    best_observed.append(best_value)\n",
    "\n",
    "    # reinitialize the model so it is ready for fitting on next iteration\n",
    "    model.set_train_data(train_x, train_obj, strict=False)\n",
    "    \n",
    "    print(\".\", end='')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "EI recommends the best point observed so far. We can visualize what the images corresponding to recommended points *would have* been if the BO process ended at various times. Here, we show the progress of the algorithm by examining the images at 0%, 10%, 25%, 50%, 75%, and 100% completion. The first image is the best image found through the initial random batch."
   ]
  },
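  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The number of observations available at a given completion percentage is the size of the initial design plus the candidates evaluated so far. Here is a minimal sketch of that bookkeeping in plain Python, assuming 5 initial points, 50 iterations, and batches of 3 as set in the code above (`points_seen` is a hypothetical helper).\n",
    "\n",
    "```python\n",
    "def points_seen(percent, n_init=5, n_batch=50, batch_size=3):\n",
    "    # Initial random points plus the candidates evaluated up to `percent` of the run.\n",
    "    return n_init + int(n_batch * batch_size * percent / 100)\n",
    "\n",
    "counts = [points_seen(p) for p in (0, 10, 25, 50, 75, 100)]\n",
    "print(counts)\n",
    "```\n",
    "\n",
    "Under these assumptions, the best image at 0% is chosen from the 5 initial points only, and the one at 100% from all 155 evaluated points."
   ]
  },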
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAy8AAACRCAYAAADKBYeoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAQMUlEQVR4nO3d349kZZkH8OdlmhkRNXF2DRIlixdkZbwymdkZaDCCSGaNieFmBRPExIQbN2LkwnHXP4Ab5GpvSCRjjAIbNZEYg7CIjIsMzExiWMQg7iZmMCiLXsyGHzPTw7sXXWqfM6e7urpO1Xnfrs8nId1PVfU5T536QvHknLcq5ZwDAACgdBcM3QAAAMBmGF4AAIAqGF4AAIAqGF4AAIAqGF4AAIAqGF4AAIAqTDW8pJQOppReSCn9JqV0qK+mqJtc0EUu6CIXdJELusgFERFpq9/zklLaERG/joiPR8RLEXEsIm7JOT/fX3vURi7oIhd0kQu6yAVd5II/W5rib/8hIn6Tc/6fiIiU0gMR8amIWDdES0tLedeuXVPskqG8/vrrKznnCzfxULlYIHJBF7mgi1zQZVa5kIm6vf7666/mnN/Tdd80w8v7IuLkmvqliNi/0R/s2rUrrrzyyil2yVBOnDhxZpMPlYsFIhd0kQu6yAVdZpULmajbiRMnfrvefdMML5uSUro9Im6PiNi5c+esd0cl5IIuckEXuaCLXNAmE4thmgX7v4uIy9bU7x/d1pBzvjfnvDfnvHdpaeazEsOTC7rIBV3kgi5yQZexuZCJxTDN8HIsIq5IKX0gpbQzIm6OiIf6aYuKyQVd5IIuckEXuaCLXBARU1w2lnNeSSn9c0T8OCJ2RMR9Oedf9tYZVZILusgFXeSCLnJBF7ngz6Y6p5Zz/lFE/KinXtgm5IIuckEXuaCLXNBFLoiY8ksqAQAA5sXwAgAAVMHwAgAAVMHwAgAAVMHwAgAAVMHwAgAAVMHwAgAAVMHwAgAAVMHwAgAAVGFp6Abm6cknn2zUu3btmuv+9+7dO9f9sTnHjx8fdP9yUSa5oMvQuWiTkzLIBW3jMuE12jpnXgAAgCoYXgAAgCoYXgAAgCos1JqXEydONOqrr756oE6Yp2eeeaZRX3BBWTN7+7pY18HOh1zQpbS1C+PIyXzIBW3TZqLvTO3fv79Rnzt3rtftl6Ssd2sAAIB1GF4AAIAqGF4AAIAqLNSaly9+8Ysb3t/39YeuMS1DaWsZKINcEFHfWoZxrHXoh1zQNutMnDlzplHv3Llzor9/+umn+2xnarPMmHdvAACgCoYXAACgCoYXAACgCgu15mWc9vV5k17fuLKy0mc79GTa13XS7Y/TXmvxoQ99qM922CS5YBE88MADjfrmm28eqBNKIhfzN+0akO22DmsazrwAAABVMLwAAABVMLwAAABVsOalRwcOHBi6BTZh6M+3f+uttxr1c889N1AnrCUXi2HS68b37dvXqHPOfbYzsaWl5tv20aNHN3y8tQybIxe0pZQmevys30NmvU6zJs68AAAAVTC8AAAAVTC8AAAAVbDmBeasfZ3quXPnGvX+/fvn2Q6FkIv5aF83vmPHjkbdPu6laX+f2OnTpxv12bNn59nOtiEXtLXXMQ29LrJ08zw+zrwAAABVGDu8pJTuSym9klJ6bs1tu1NKj6aUXhz9fPds26Q0ckEXuaCLXNBFLugiF4yzmTMvhyPiYOu2QxHxWM75ioh4bFSzWA6HXHC+wyEXnO9wyAXnOxxywfkOh1ywgbFrXnLOR1JKl7du/lREfHT0+zcj4qcR8ZUe+6pS+5r17Xx9pFys7+KLL27UTzzxxIaPb19bXTO5WJ9cpMtbNxeRi6HXMhw5cqRRv/3tb2/UDz74YKP+9Kc/3ajb3w+0vLzcY3ezJRfrk4t0eevmInIxpHl/r8ubb77ZqK+55pq57n8jW13zcknO+eXR77+P
iEt66oe6yQVd5IIuckEXuaCLXPAXUy/Yz6sfx7DuV8umlG5PKR1PKR1vfxoG25dc0EUu6CIXdJELumyUC5lYDFsdXv6QUro0ImL085X1HphzvjfnvDfnvHdpySczb3NyQRe5oItc0EUu6LKpXMjEYtjqK/tQRNwWEXeNfv6gt44GNO/rCbehbZmLSY1by9C2nddGjchFyEUHuYjz1zK0tdcytN155519tlMCuQi56LBwuej7/0m303vKZj4q+f6IeCoi/j6l9FJK6fOxGp6Pp5RejIgbRjULRC7oIhd0kQu6yAVd5IJxNvNpY7esc9fHeu6FisgFXeSCLnJBF7mgi1wwztQL9gEAAObBaqYenT59eugWGMAPf/jDiR7/8MMPz6gTSiIXzMPPfvazoVugQHJRH2tcNs+ZFwAAoAqGFwAAoAqGFwAAoArWvPRoeXl56BaYg0mvSz137lyj/trXvtZnOxRCLgDYLN8tuHXOvAAAAFUwvAAAAFUwvAAAAFWw5mUK999//0SPX1pqHu6VlZU+26EnKaVGfezYsam2t2PHjqn+njLIBX34yU9+0uv22tfNb+fvdtjO5IK+becMOPMCAABUwfACAABUwfACAABUwZqXKdx9992Nuu/P7D516lSjvv7663vdPqt2797dqB955JGZ7m/o61AvuuiiRv3GG2/Mdf+1kAu5mIV3vvOdM93+0Dlia+Ri8eScG3V7XWXfPvOZzzTq73znOzPd3yw58wIAAFTB8AIAAFTB8AIAAFTBmpc12teAfu5zn2vUhw8fbtR9r3Fpe9e73jXT7bPqtddea9Tt79959dVXG/UnP/nJibbfzkn7Ote+TZrLgwcPNur2811UciEXs7Bv374N75/1+wplkovFM+/X/Mtf/nKjtuYFAABgxgwvAABAFQwvAABAFax52UB7jQvb0+nTpxv1gQMHet3+8vLyhvub1LXXXtuo77nnnqm29/DDDzdqn/+/Si7kYgjt47xr165G/eSTT060Pd/vsT3IxeJpvyZ9r4Fpb6/v96RZcuYFAACoguEFAACoguEFAACogjUvE/A562xF39eNTruWgTLIxTCeeOKJRn3xxRdv+Pih1wK0c9LuZ2mp+TZ+9OjRDbdnrUM3uZCL0rVfk5RSo37mmWc2vH+c9rqpkjPgzAsAAFAFwwsAAFAFwwsAAFAFa14G9NZbbzXqCy5ozpJXX331PNuhUNZa0UUutmbcWoa20tcCjFvLME77+d14442N+k9/+tNU26+FXDTJRXne+973Nurvfve7jXrSNS7jlJxxZ14AAIAqjB1eUkqXpZQeTyk9n1L6ZUrpjtHtu1NKj6aUXhz9fPfs26UUckEXuaCLXNBFLugiF4yzmTMvKxFxZ855T0QciIgvpJT2RMShiHgs53xFRDw2qlkcckEXuaCLXNBFLugiF2xo7JqXnPPLEfHy6Pf/Syn9KiLeFxGfioiPjh72zYj4aUR8ZSZdFuL5559v1Hv27GnUJ0+ebNQ33XTThtu77rrrGvXjjz8+RXfzJRezU/NaBrmYHbnoR/u67XHH9c0335xo++3vWmivZSzdI4880qjPnDnTqPtciykX9VjUXAyp5v/mz9pE//aklC6PiA9HxNMRcckoYBERv4+IS3rtjGrIBV3kgi5yQRe5oItc0GXTw0tK6R0R8b2I+FLO+dTa+3LOOSLyOn93e0rpeErp+MrKylTNUh65oItc0EUu6CIXdNlKLmRiMWxqeEkpXRirAfp2zvn7o5v/kFK6dHT/pRHxStff5pzvzTnvzTnvXVryyczbiVzQRS7oIhd0kQu6bDUXMrEYxr6yafWDo78REb/KOX99zV0PRcRtEXHX6OcPZtJhQT772c/2ur2a1ri0LXIu2p+lfuzYsYE66ccf//jH3rYlF38lF39Vci6m/e6CAwcONOra1jKMs3PnzpltWy7qtYi5aD/nD37wg4362WefnWh7P//5zzfc/tBK+l6Xts2MpcsR
cWtE/FdK6Rej2/4lVsPz7ymlz0fEbyPin2bTIoWSC7rIBV3kgi5yQRe5YEOb+bSx/4yI9b6282P9tkMt5IIuckEXuaCLXNBFLhhne53HBAAAti2rmaBlu3+2+vXXX9+oT506tc4jWUsu6HL06NGhW5ipkq97L5lc1O/IkSONur1u6W1ve9s825m5ml5TZ14AAIAqGF4AAIAqGF4AAIAqWPMChbnpppsa9cmTJwfqhJLIRR3a140/9dRTjfrCCy+cZzvn2bdvX6Ne/aJyZk0u6vORj3ykUX/rW99q1FdeeeU825laTWtaxnHmBQAAqILhBQAAqILhBQAAqII1L9Ay6XWhd911V6O+4YYbGvVVV13VqM+ePbu1xhiUXLAV7dcZIuSiRrfeeuuG91900UWNuv0esLy83Gs/22kNy6SceQEAAKpgeAEAAKpgeAEAAKpgzQtM6dChQ0O3QIHkAmBxvPHGG436jjvuGKiT7c+ZFwAAoAqGFwAAoAqGFwAAoAqGFwAAoAqGFwAAoAqGFwAAoAqGFwAAoAqGFwAAoAqGFwAAoAqGFwAAoAqGFwAAoAop5zy/naX0vxHx24j424h4dW47npz+zvd3Oef3zGLDctGLoXqTi7L7k4thlNxbxPbNxWvhuE9jW+Wikv9WROhvPevmYq7Dy192mtLxnPPeue94k/Q3jNKfV8n9ldzbtEp/biX3V3Jv0yr5uZXcW0T5/W1V6c9Lf8Mo/Xnpb3IuGwMAAKpgeAEAAKow1PBy70D73Sz9DaP051VyfyX3Nq3Sn1vJ/ZXc27RKfm4l9xZRfn9bVfrz0t8wSn9e+pvQIGteAAAAJuWyMQAAoApzHV5SSgdTSi+klH6TUjo0z32v0899KaVXUkrPrbltd0rp0ZTSi6Of7x6wv8tSSo+nlJ5PKf0ypXRHaT32QS4m7k8uhulHLgogFxP3JxfD9CMXBZCLifurIhdzG15SSjsi4t8i4h8jYk9E3JJS2jOv/a/jcEQcbN12KCIeyzlfERGPjeqhrETEnTnnPRFxICK+MDpmJfU4FbnYErkYxuGQi0HJxZbIxTAOh1wMSi62pI5c5Jzn8k9EXBURP15TfzUivjqv/W/Q1+UR8dya+oWIuHT0+6UR8cLQPa7p7QcR8fGSe5QLuZCL4Y+fXMiFXMiFXMjFds3FPC8be19EnFxTvzS6rTSX5JxfHv3++4i4ZMhm/iyldHlEfDgino5Ce9wiuZiCXAyuyGMuF4Mr8pjLxeCKPOZyMbgij3nJubBgfwN5dcQc/OPYUkrviIjvRcSXcs6n1t5XSo+LpJRjLhdlKeWYy0VZSjnmclGWUo65XJSllGNeei7mObz8LiIuW1O/f3Rbaf6QUro0ImL085Uhm0kpXRirAfp2zvn7o5uL6nFKcrEFclGMoo65XBSjqGMuF8Uo6pjLRTGKOuY15GKew8uxiLgipfSBlNLOiLg5Ih6a4/4366GIuG30+22xer3fIFJKKSK+ERG/yjl/fc1dxfTYA7mYkFwUpZhjLhdFKeaYy0VRijnmclGUYo55NbmY88KfT0TEryPivyPiX4dc7DPq5/6IeDkizsbqtZCfj4i/idVPUngxIv4jInYP2N81sXpq7tmI+MXon0+U1KNcyIVclHHM5UIu5EIu5EIuFiEXadQsAABA0SzYBwAAqmB4AQAAqmB4AQAAqmB4AQAAqmB4AQAAqmB4AQAAqmB4AQAAqmB4AQAAqvD/gOhs3QiSROMAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 1008x1008 with 6 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots(1, 6, figsize=(14, 14))\n",
    "percentages = np.array([0, 10, 25, 50, 75, 100], dtype=np.float32)\n",
    "inds = (N_BATCH * BATCH_SIZE * percentages / 100 + 4).astype(int)\n",
    "\n",
    "for i, ax in enumerate(ax.flat):\n",
    "    b = torch.argmax(score_image_recognition(decode(train_x[:inds[i],:])), dim=0)\n",
    "    img = decode(train_x[b].view(1, -1)).squeeze().cpu()\n",
    "    ax.imshow(img, alpha=0.8, cmap='gray')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
